{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook relies on the pre-1.0 AllenNLP API (allennlp.data.iterators and\n",
    "# Trainer(iterator=...)), which AllenNLP 1.0 removed, so pin the last 0.x release.\n",
    "!pip install allennlp==0.9.0\n",
    "!git clone https://github.com/mhagiwara/realworldnlp.git\n",
    "%cd realworldnlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from allennlp.data.dataset_readers import UniversalDependenciesDatasetReader\n",
    "from allennlp.data.iterators import BucketIterator\n",
    "from allennlp.data.vocabulary import Vocabulary\n",
    "from allennlp.models import Model\n",
    "from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper\n",
    "from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder\n",
    "from allennlp.modules.token_embedders import Embedding\n",
    "from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits\n",
    "from allennlp.training.metrics import CategoricalAccuracy\n",
    "from allennlp.training.trainer import Trainer\n",
    "\n",
    "from realworldnlp.predictors import UniversalPOSPredictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensionality of the token embeddings; also used as the LSTM input size.\n",
    "EMBEDDING_SIZE = 128\n",
    "# Hidden size of the LSTM encoder.\n",
    "HIDDEN_SIZE = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LstmTagger(Model):\n",
    "    \"\"\"Tagger that embeds the input, runs a sequence encoder over it, and projects to tag logits.\n",
    "\n",
    "    Args:\n",
    "        embedder: Maps the ``words`` text field to embedding vectors.\n",
    "        encoder: Sequence-to-sequence encoder run over the embeddings.\n",
    "        vocab: Vocabulary; the size of its ``'pos'`` namespace sets the number of output tags.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 embedder: TextFieldEmbedder,\n",
    "                 encoder: Seq2SeqEncoder,\n",
    "                 vocab: Vocabulary) -> None:\n",
    "        super().__init__(vocab)\n",
    "        self.embedder = embedder\n",
    "        self.encoder = encoder\n",
    "\n",
    "        # Linear layer from the encoder's output size to one logit per 'pos' tag.\n",
    "        encoder_dim = encoder.get_output_dim()\n",
    "        num_tags = vocab.get_vocab_size('pos')\n",
    "        self.linear = torch.nn.Linear(in_features=encoder_dim, out_features=num_tags)\n",
    "\n",
    "        self.accuracy = CategoricalAccuracy()\n",
    "\n",
    "    def forward(self,\n",
    "                words: Dict[str, torch.Tensor],\n",
    "                pos_tags: torch.Tensor = None,\n",
    "                **args) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"Compute tag logits and, when gold tags are supplied, the loss.\n",
    "\n",
    "        Args:\n",
    "            words: The text-field tensors to tag.\n",
    "            pos_tags: Gold tags; when given, the accuracy metric is updated and a loss is returned.\n",
    "            **args: Any further keyword arguments are accepted and ignored.\n",
    "\n",
    "        Returns:\n",
    "            A dict with ``\"tag_logits\"`` and, if ``pos_tags`` was given, ``\"loss\"``.\n",
    "        \"\"\"\n",
    "        mask = get_text_field_mask(words)\n",
    "        encoded = self.encoder(self.embedder(words), mask)\n",
    "        tag_logits = self.linear(encoded)\n",
    "\n",
    "        # Without gold tags there is nothing to score against.\n",
    "        if pos_tags is None:\n",
    "            return {\"tag_logits\": tag_logits}\n",
    "\n",
    "        # The same mask is handed to both the metric and the loss.\n",
    "        self.accuracy(tag_logits, pos_tags, mask)\n",
    "        loss = sequence_cross_entropy_with_logits(tag_logits, pos_tags, mask)\n",
    "        return {\"tag_logits\": tag_logits, \"loss\": loss}\n",
    "\n",
    "    def get_metrics(self, reset: bool = False) -> Dict[str, float]:\n",
    "        \"\"\"Return the accuracy metric, forwarding ``reset`` to it.\"\"\"\n",
    "        accuracy = self.accuracy.get_metric(reset)\n",
    "        return {\"accuracy\": accuracy}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reader for Universal Dependencies data; used to load the CoNLL-U files and later passed to the predictor.\n",
    "reader = UniversalDependenciesDatasetReader()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CoNLL-U train/dev files of the UD English-EWT treebank, as named in the URLs.\n",
    "TRAIN_DATA_URL = 'https://s3.amazonaws.com/realworldnlpbook/data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-train.conllu'\n",
    "DEV_DATA_URL = 'https://s3.amazonaws.com/realworldnlpbook/data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-dev.conllu'\n",
    "\n",
    "train_dataset = reader.read(TRAIN_DATA_URL)\n",
    "dev_dataset = reader.read(DEV_DATA_URL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the vocabulary from the training and development instances together.\n",
    "all_instances = train_dataset + dev_dataset\n",
    "vocab = Vocabulary.from_instances(all_instances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One embedding row per entry in the 'tokens' vocabulary namespace.\n",
    "token_vocab_size = vocab.get_vocab_size('tokens')\n",
    "token_embedding = Embedding(num_embeddings=token_vocab_size, embedding_dim=EMBEDDING_SIZE)\n",
    "\n",
    "word_embeddings = BasicTextFieldEmbedder({\"tokens\": token_embedding})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap a PyTorch LSTM so it can serve as the tagger's sequence encoder.\n",
    "lstm = torch.nn.LSTM(input_size=EMBEDDING_SIZE,\n",
    "                     hidden_size=HIDDEN_SIZE,\n",
    "                     batch_first=True)\n",
    "encoder = PytorchSeq2SeqWrapper(lstm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the tagger from the embedder, the encoder and the vocabulary.\n",
    "model = LstmTagger(embedder=word_embeddings,\n",
    "                   encoder=encoder,\n",
    "                   vocab=vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all of the model's parameters; no optimizer settings are overridden.\n",
    "optimizer = optim.Adam(params=model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches of 16, bucketed on the number of tokens in the \"words\" field.\n",
    "iterator = BucketIterator(\n",
    "    sorting_keys=[(\"words\", \"num_tokens\")],\n",
    "    batch_size=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the vocabulary to the iterator before the iterator is given to the Trainer.\n",
    "iterator.index_with(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wire the model, optimizer, iterator and datasets into a Trainer, then run training.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    iterator=iterator,\n",
    "    train_dataset=train_dataset,\n",
    "    validation_dataset=dev_dataset,\n",
    "    patience=10,\n",
    "    num_epochs=10,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag a sample sentence with the trained model.\n",
    "predictor = UniversalPOSPredictor(model, reader)\n",
    "tokens = ['The', 'dog', 'ate', 'the', 'apple', '.']\n",
    "\n",
    "prediction = predictor.predict(tokens)\n",
    "logits = prediction['tag_logits']\n",
    "\n",
    "# Argmax over the last axis gives one tag id per row of logits;\n",
    "# each id is then looked up in the 'pos' vocabulary namespace.\n",
    "tag_ids = np.argmax(logits, axis=-1)\n",
    "\n",
    "[vocab.get_token_from_index(tag_id, 'pos') for tag_id in tag_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
